import torch


def can_use_flash_attn(device_id=0):
    """Check if a GPU supports FlashAttention.

    A device is considered supported when its reported compute capability
    is 8.x (Ampere or newer in the SM 8 family) or exactly 9.0. Devices whose
    name contains ``'MI308X'`` are accepted for any 9.x capability.

    Args:
        device_id: Index of the device to query. Defaults to 0.

    Returns:
        ``True`` if the device's capability is in the supported range,
        ``False`` otherwise -- including when no CUDA/ROCm device is
        available at all.
    """
    # Without a usable CUDA/ROCm runtime there is nothing to query, and
    # get_device_capability would raise rather than answer the question.
    if not torch.cuda.is_available():
        return False

    major, minor = torch.cuda.get_device_capability(device_id)
    # Search the full name rather than only its last whitespace token: that
    # avoids an IndexError on an empty name and still matches if "MI308X"
    # is not the final word.
    device_name = torch.cuda.get_device_name(device_id)

    # SM 8.x: every minor revision is supported.
    if major == 8:
        return True
    if major == 9:
        # In general only SM 9.0 is accepted; MI308X is accepted for any
        # 9.x minor. NOTE(review): presumably because ROCm reports this part
        # with a non-zero minor (e.g. gfx942 -> 9.4) -- confirm.
        return minor == 0 or 'MI308X' in device_name
    return False